import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib import collections as matcoll
import os

# This script configures a Matplotlib/seaborn plotting style and then plots the
# relative sub-optimality of Local SGD variants and baselines versus the number of communication rounds.

###################################################################################################
# Tweaking seaborn to make our curves look nicer :)
# Seaborn lets us override Matplotlib rcParams directly through its theme functions.
# Inspired by: https://towardsdatascience.com/making-matplotlib-beautiful-by-default-d0d41e3534fd

# Matplotlib rcParams overridden through seaborn's theme: black axes and text,
# no axes background fill, no grid, no top/right spines, and tick marks
# switched off on all four sides (tick colour/direction are still set).
_style_rc = {
    # Axes appearance
    'axes.axisbelow': False,
    'axes.edgecolor': 'black',
    'axes.facecolor': 'None',
    'axes.grid': False,
    'axes.labelcolor': 'black',
    'axes.spines.right': False,
    'axes.spines.top': False,
    # Figure, lines, patches and text
    'figure.facecolor': 'white',
    'lines.solid_capstyle': 'round',
    'patch.edgecolor': 'w',
    'patch.force_edgecolor': True,
    'text.color': 'black',
    # Hide the tick marks themselves on every side
    'xtick.bottom': False,
    'xtick.top': False,
    'ytick.left': False,
    'ytick.right': False,
    # Tick colour and direction
    'xtick.color': 'black',
    'xtick.direction': 'out',
    'ytick.color': 'black',
    'ytick.direction': 'out',
}

# Global font sizes layered on top of seaborn's "notebook" context.
_context_rc = {"font.size": 15, "axes.titlesize": 15, "axes.labelsize": 15}

# NOTE(review): 'Franklin Gothic Book' presumably must be installed locally;
# otherwise Matplotlib falls back to another font.
sns.set(font='Franklin Gothic Book', rc=_style_rc)
sns.set_context("notebook", rc=_context_rc)

# The rc dicts were only needed for the two calls above.
del _style_rc, _context_rc

# Defining colour names
# Named hex colours, grouped as base colour / lighter variant where both exist.
Blue, Light_Blue = '#0000FF', '#add8e6'
Green, Light_Green = '#008000', '#32CD32'
Pink = '#FFC0CB'
Purple, Light_Purple = '#800080', '#CBC3E3'
Violet, Light_Violet = '#8F00FF', '#CC99FF'
CB91_Amber = '#F5B14C'
Red, Light_Red = '#FF0000', '#ffcccb'
Yellow = '#FFFF00'
Black, Gray = '#000000', '#808080'
Orange, Light_Orange = '#FFA500', '#FED8B1'

# Setting default colour for plotting and cycling through them
# Eight colours, one for each of the eight curves drawn per subplot in the
# plotting loop (assumes each Axes starts its own colour cycle).
# Alternatives that were tried:
#   [Red, Black, Gray, Green, Purple]  or  sns.color_palette("colorblind")
color_list = [
    Red, Light_Red,
    Black, Gray,
    Green, Light_Green,
    Purple, Light_Purple,
]

# Install the colour cycle and a thin marker edge in one rcParams update.
plt.rcParams.update({
    'axes.prop_cycle': plt.cycler(color=color_list),
    'lines.markeredgewidth': 1,
})

#############################################################################
# T_values = [100, 1000]
T_values = [100]

# Reference optimal loss per regularisation strength mu, keyed by str(mu).
# Curves are plotted as relative sub-optimality (loss - optimal) / optimal.
optimal_values = {"0.01": 0.3727237468639262, "0.0001": 0.32450692471375703,
                  "1e-06": 0.3226712387963549, "1e-08": 0.3226220624005153}

# (results key, legend label, extra errorbar style) for every curve drawn on a
# subplot. Order matters: each curve takes the next colour from the axes'
# colour cycle, so reordering entries changes the colours.
method_specs = [
    ("Newton_w_M", "FedSN-Lite w/ Mom. (Our Method)", {"linewidth": 2.5}),
    ("Newton", "FedSN-Lite (Our Method)", {"linestyle": "dotted", "linewidth": 1}),
    ("FEDAC2", "FedAC-2 (w/ Knowledge of \u03bc)", {"linewidth": 2}),
    ("FEDAC1", "FedAC-1 (w/ Knowledge of \u03bc)", {"linewidth": 2}),
    ("MBSGD_w_M", "MB SGD w/ Mom.", {"linewidth": 2}),
    ("MBSGD", "MB SGD", {"linestyle": "dotted", "linewidth": 1}),
    ("LSGD_w_M", "Local SGD w/ Mom.", {"linewidth": 2}),
    ("LSGD", "Local SGD", {"linestyle": "dotted", "linewidth": 1}),
]


def _log_suboptimality(losses, optimal):
    """Return ``(y, y_err)`` for one method's curve on a log10 scale.

    ``losses`` is indexed as ``losses[point][repetition]`` (statistics are
    taken over axis 1); its first entry along axis 0 is dropped. ``y`` is
    log10 of the mean relative sub-optimality ``(loss - optimal) / optimal``.
    ``y_err`` is ``log10(1 + std / mean)``, i.e. the distance from
    ``log10(mean)`` to ``log10(mean + std)``, used as a symmetric error bar.
    """
    rel = (np.array(losses)[1:] - optimal) / optimal
    mean = np.mean(rel, axis=1)
    y = np.log10(mean)
    y_err = np.log10(1 + np.std(rel, axis=1) / mean)
    return y, y_err


for T in T_values:

    # Communication-round grid for this horizon T.
    if T == 1000:
        R_values = [1, 5, 10, 25, 50, 100, 500, 1000]
    elif T == 100:
        R_values = [1, 5, 10, 25, 50, 100]
    else:
        # Without this, R_values would be undefined (NameError) or silently
        # left over from the previous T.
        raise ValueError(f"No communication-round grid defined for T={T}")

    # mu_values = [1e-2, 1e-4, 1e-6, 1e-8]
    mu_values = [1e-4, 1e-6]
    # M_values = [20, 50, 100, 200, 500]
    M_values = [100, 200]

    # _log_suboptimality drops the first entry of each series, leaving
    # len(R_values) - 1 points, labelled with R_values[1:].
    # NOTE(review): assumes entry k of each saved series corresponds to
    # R_values[k]; confirm against the script that wrote results/.
    plotted_R = R_values[1:]
    x = np.arange(len(plotted_R))

    # Grid shape follows the parameter lists (rows: M, columns: mu).
    # squeeze=False keeps axs 2-D even for a single row or column.
    n_rows, n_cols = len(M_values), len(mu_values)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows),
                            squeeze=False)

    for i, M in enumerate(M_values):
        for j, mu in enumerate(mu_values):
            ax = axs[i, j]
            optimal = optimal_values[str(mu)]
            # The .npy file holds a pickled dict (hence allow_pickle and
            # .item()); only load result files you generated yourself.
            losses = np.load(f"results/{T}/Losses_{str(mu)}_{str(M)}.npy",
                             allow_pickle=True).item()

            for key, label, style in method_specs:
                y, y_err = _log_suboptimality(losses[key], optimal)
                ax.errorbar(x, y, yerr=y_err, fmt='-o', label=label,
                            capsize=3, **style)

            # Pin one tick per plotted point before labelling. set_ticklabels
            # alone pairs labels with whatever ticks the auto locator chose
            # (matched by index, off-screen ticks included).
            ax.set_xticks(x)
            ax.set_xticklabels(plotted_R)

            # Only the bottom row keeps its x tick labels.
            if i != n_rows - 1:
                ax.tick_params(left=False, right=False, labelleft=True,
                               labelbottom=False, bottom=False)

    # Column titles (mu) and row labels (M).
    for ax, mu in zip(axs[0], mu_values):
        ax.set_title("\u03bc=" + str(mu))
    for ax, M in zip(axs[:, 0], M_values):
        ax.set_ylabel("M=" + str(M), rotation=0, labelpad=70)

    # One legend for the whole figure, attached to the last subplot and pushed
    # above the grid; the anchor offsets appear tuned for a 2x2 grid.
    axs[-1, -1].legend(loc='upper center', bbox_to_anchor=(-0.2, 2.7),
                       fancybox=False, shadow=False, ncol=2,
                       prop={'size': 15})

    fig.subplots_adjust(wspace=0.20, hspace=0.05)

    fig.text(0.5, 0.03, 'Number of Communication Rounds', ha='center')
    fig.text(0.04, 0.5, 'Best Log(Sub-optimality)',
             va='center', rotation='vertical')

    # Save the figure, creating the output directory if needed.
    out_dir = f"figures/{T}"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, "plot_w_rep_small.png"), dpi=500,
                orientation="portrait", bbox_inches='tight')
    # Release the figure so looping over several T values does not keep
    # every figure open.
    plt.close(fig)
